// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <curand_kernel.h>
#include <cstdlib>
#include <string>
#include "helper.h"  // NOLINT

// Linear membership test: returns true when `draft` equals one of the first
// `candidate_len` entries of `candidates`, false otherwise.
// When `candidate_len` is zero or negative nothing is read and the result is
// false.
__device__ inline bool is_in(const int64_t *candidates,
                             const int64_t draft,
                             const int candidate_len) {
    bool found = false;
    // Stop scanning as soon as a match has been seen.
    for (int pos = 0; pos < candidate_len && !found; ++pos) {
        found = (candidates[pos] == draft);
    }
    return found;
}

// Host-side counters passed to setup_kernel as the cuRAND seed and offset.
// SpeculateVerify increments both right after each setup_kernel launch, so
// successive calls initialise the generator states with a different
// (seed, offset) pair.
// NOTE(review): these are plain statics modified with an unsynchronised ++;
// this assumes SpeculateVerify is never called concurrently from several host
// threads — TODO confirm.
static uint64_t seed = 0;
static uint64_t offset = 0;

// Samples one token id from a candidate list by walking the cumulative score.
//
// A uniform random number is drawn from the generator state stored at slot
// threadIdx.x of `dev_curand_states` and scaled by `topp`. The function
// returns the id of the first candidate whose running score total reaches
// that threshold.
//
//   candidate_ids     ids of the candidates, in the order they are scanned
//   candidate_scores  score of each candidate, same order as candidate_ids
//   dev_curand_states per-thread generator states, indexed by threadIdx.x
//   candidate_len     number of candidates to scan
//   topp              scale applied to the uniform draw
//
// If no prefix of the scores reaches the threshold (including the case
// candidate_len <= 0), candidate_ids[0] is returned.
__device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids,
                                        const float *candidate_scores,
                                        curandState_t *dev_curand_states,
                                        const int candidate_len,
                                        const float topp) {
    // Drawing advances the state owned by this thread.
    curandState_t *const my_state = &dev_curand_states[threadIdx.x];
    const float threshold = curand_uniform(my_state) * topp;

    float running_total = 0.0f;
    int picked = 0;  // stays 0 when the threshold is never reached
    for (int k = 0; k < candidate_len; ++k) {
        running_total += candidate_scores[k];
        if (threshold <= running_total) {
            picked = k;
            break;
        }
    }
    return candidate_ids[picked];
}

// Initialises the first `bs` cuRAND generator states in `state`.
//
// Work is distributed with a grid-stride loop, so any launch configuration
// covers all `bs` slots. Every slot is initialised with the same `seed` and
// `offset`; the subsequence argument is the slot index when
// `need_batch_random` is true and 0 for every slot otherwise.
__global__ void setup_kernel(curandState_t *state,
                             const uint64_t seed,
                             const uint64_t offset,
                             const int bs,
                             const bool need_batch_random) {
    const int first_slot = blockIdx.x * blockDim.x + threadIdx.x;
    for (int slot = first_slot; slot < bs; slot += gridDim.x * blockDim.x) {
        const int subsequence = need_batch_random ? slot : 0;
        curand_init(seed, subsequence, offset, &state[slot]);
    }
}

// Verifies the draft tokens of every sequence in the batch against the
// candidate tokens produced by the target model, writes the accepted tokens,
// and then emits one more token for the first position that was not accepted.
//
// Launch layout: the batch index is taken from threadIdx.x only (blockIdx is
// never read), so the kernel expects a single block with at least `real_bsz`
// threads; each thread handles one sequence on its own. Launching more than
// one block would make several threads process the same sequence. No shared
// memory and no barriers are used.
//
// Template parameters:
//   ENABLE_TOPP  when true the extra token is drawn with topp_sampling_kernel;
//                when false the first candidate of that position is taken.
//   USE_TOPK     when true a draft token is accepted only if it equals the
//                first candidate of its position; when false it is accepted
//                if it appears anywhere in the (clamped) candidate list, with
//                a second-candidate + verify-window fallback.
//
// Per-sequence data layout, as used by the indexing below:
//   start_token_id = bid * max_seq_len - output_cum_offsets[bid]
//   verify_tokens / verify_scores : row (start_token_id + i), each row holding
//                                   max_candidate_len entries
//   actual_candidate_len          : entry (start_token_id + i)
//   draft_tokens / accept_tokens  : row bid, each row holding max_draft_tokens
//                                   entries; draft position i is compared via
//                                   draft_tokens[i + 1]
//
// Outputs (all updated in place): accept_tokens, accept_num, step_idx,
// stop_flags. accept_num[bid] is only written for sequences that are neither
// block-stepped nor already stopped on entry.
//
// `seq_lens_decoder` and `actual_draft_token_nums` are accepted for interface
// compatibility but are not read by this kernel.
template <bool ENABLE_TOPP, bool USE_TOPK>
__global__ void speculate_verify(int64_t *accept_tokens,
                                 int *accept_num,
                                 int64_t *step_idx,
                                 bool *stop_flags,
                                 const int *seq_lens_encoder,
                                 const int *seq_lens_decoder,
                                 const int64_t *draft_tokens,
                                 const int *actual_draft_token_nums,
                                 curandState_t *dev_curand_states,
                                 const float *topp,
                                 const int *seq_lens_this_time,
                                 const int64_t *verify_tokens,
                                 const float *verify_scores,
                                 const int64_t *max_dec_len,
                                 const int64_t *end_tokens,
                                 const bool *is_block_step,
                                 const int *output_cum_offsets,
                                 const int *actual_candidate_len,
                                 const int real_bsz,
                                 const int max_draft_tokens,
                                 const int end_length,
                                 const int max_seq_len,
                                 const int max_candidate_len,
                                 const int verify_window,
                                 const bool prefill_one_step_stop) {
    const int bid = threadIdx.x;
    // Surplus threads (the block is usually larger than the batch) must leave
    // before any per-sequence array is indexed; reading output_cum_offsets[bid]
    // or is_block_step[bid] for bid >= real_bsz would be out of bounds.
    if (bid >= real_bsz) {
        return;
    }
    const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
    // verify and set stop flags
    int accept_num_now = 1;
    int stop_flag_now_int = 0;

    if (!is_block_step[bid]) {
        if (stop_flags[bid]) {
            stop_flag_now_int = 1;
        } else {
            // The prefill stage also enters here, but because the draft tokens
            // are zeroed it goes straight to the final sampling stage.
            auto *verify_tokens_now =
                verify_tokens + start_token_id * max_candidate_len;
            auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
            auto *actual_candidate_len_now =
                actual_candidate_len + start_token_id;

            int i = 0;
            for (; i < seq_lens_this_time[bid] - 1; i++) {
                // A non-zero encoder length skips verification entirely.
                if (seq_lens_encoder[bid] != 0) {
                    break;
                }
                if (USE_TOPK) {
                    // Accept only an exact match with the first candidate.
                    if (verify_tokens_now[i * max_candidate_len] ==
                        draft_tokens_now[i + 1]) {
                        step_idx[bid]++;
                        auto accept_token = draft_tokens_now[i + 1];
                        accept_tokens[bid * max_draft_tokens + i] =
                            accept_token;
                        if (is_in_end(accept_token, end_tokens, end_length) ||
                            step_idx[bid] >= max_dec_len[bid]) {
                            stop_flags[bid] = true;
                            stop_flag_now_int = 1;
                            // Length limit reached: overwrite the token with
                            // the first end token.
                            if (step_idx[bid] >= max_dec_len[bid])
                                accept_tokens[bid * max_draft_tokens + i] =
                                    end_tokens[0];
                            break;
                        } else {
                            accept_num_now++;
                        }
                    } else {
                        break;
                    }
                } else {
                    // Never scan more candidates than a row can hold.
                    auto actual_candidate_len_value =
                        actual_candidate_len_now[i] > max_candidate_len
                            ? max_candidate_len
                            : actual_candidate_len_now[i];
                    if (is_in(verify_tokens_now + i * max_candidate_len,
                              draft_tokens_now[i + 1],
                              actual_candidate_len_value)) {
                        // Top P verify: the draft token is one of the
                        // candidates of this position.
                        step_idx[bid]++;
                        auto accept_token = draft_tokens_now[i + 1];
                        accept_tokens[bid * max_draft_tokens + i] =
                            accept_token;

                        if (is_in_end(accept_token, end_tokens, end_length) ||
                            step_idx[bid] >= max_dec_len[bid]) {
                            stop_flags[bid] = true;
                            stop_flag_now_int = 1;
                            if (step_idx[bid] >= max_dec_len[bid])
                                accept_tokens[bid * max_draft_tokens + i] =
                                    end_tokens[0];
                            break;
                        } else {
                            accept_num_now++;
                        }
                    } else {
                        // TopK verify: the draft token was not found in the
                        // clamped candidate list. If it equals the second
                        // candidate of this position, look ahead: the next
                        // `verify_window` draft tokens must each equal the
                        // first candidate of their position.
                        int ii = i;
                        if (max_candidate_len >= 2 &&
                            verify_tokens_now[ii * max_candidate_len + 1] ==
                                draft_tokens_now[ii + 1]) {  // top-2
                            int j = 0;
                            ii += 1;
                            for (; j < verify_window &&
                                   ii < seq_lens_this_time[bid] - 1;
                                 j++, ii++) {
                                if (verify_tokens_now[ii * max_candidate_len] !=
                                    draft_tokens_now[ii + 1]) {
                                    break;
                                }
                            }
                            if (j >= verify_window) {  // accept all
                                // The whole window matched: count the top-2
                                // token plus the window up front, then copy
                                // the tokens for positions i .. ii-1.
                                accept_num_now += verify_window + 1;
                                step_idx[bid] += verify_window + 1;
                                for (; i < ii; i++) {
                                    auto accept_token = draft_tokens_now[i + 1];
                                    accept_tokens[bid * max_draft_tokens + i] =
                                        accept_token;
                                    // NOTE(review): step_idx already includes
                                    // the full window here, and a stop inside
                                    // the window undoes only one of the
                                    // pre-counted tokens — confirm this is the
                                    // intended bookkeeping.
                                    if (is_in_end(accept_token,
                                                  end_tokens,
                                                  end_length) ||
                                        step_idx[bid] >= max_dec_len[bid]) {
                                        stop_flags[bid] = true;
                                        stop_flag_now_int = 1;
                                        if (step_idx[bid] >= max_dec_len[bid])
                                            accept_tokens[bid *
                                                              max_draft_tokens +
                                                          i] = end_tokens[0];
                                        accept_num_now--;
                                        step_idx[bid]--;
                                        break;
                                    }
                                }
                            }
                        }
                        break;
                    }
                }
            }
            // Sampling stage.
            // Case 1: draft_token[i + 1] was rejected, so a token has to be
            //         chosen from verify_tokens_now[i].
            // Case 2: i == seq_lens_this_time[bid] - 1, a token is likewise
            //         chosen from verify_tokens_now[i].
            // A sequence that stopped during verification is excluded.
            if (!stop_flag_now_int) {
                int64_t accept_token;
                const float *verify_scores_now =
                    verify_scores + start_token_id * max_candidate_len;
                step_idx[bid]++;
                if (ENABLE_TOPP) {
                    auto actual_candidate_len_value =
                        actual_candidate_len_now[i] > max_candidate_len
                            ? max_candidate_len
                            : actual_candidate_len_now[i];

                    accept_token = topp_sampling_kernel(
                        verify_tokens_now + i * max_candidate_len,
                        verify_scores_now + i * max_candidate_len,
                        dev_curand_states,
                        actual_candidate_len_value,
                        topp[bid]);
                } else {
                    // Greedy choice: first candidate of position i.
                    accept_token = verify_tokens_now[i * max_candidate_len];
                }
                accept_tokens[bid * max_draft_tokens + i] = accept_token;
                if (prefill_one_step_stop) {
                    stop_flags[bid] = true;
                }
                if (is_in_end(accept_token, end_tokens, end_length) ||
                    step_idx[bid] >= max_dec_len[bid]) {
                    stop_flags[bid] = true;
                    stop_flag_now_int = 1;
                    if (step_idx[bid] >= max_dec_len[bid])
                        accept_tokens[bid * max_draft_tokens + i] =
                            end_tokens[0];
                }
            }
            accept_num[bid] = accept_num_now;
        }
    }
}

// Host entry point of the speculate_verify operator.
//
// Steps:
//   1. Allocates one cuRAND state per row of `accept_tokens` and initialises
//      them with setup_kernel using the file-level seed/offset counters, which
//      are then incremented.
//   2. Reads two environment variables:
//        SPECULATE_VERIFY_USE_TOPK   parsed with std::stoi (which throws if
//                                    the value is not numeric); non-zero
//                                    selects the USE_TOPK kernel variant.
//        PREFILL_NODE_ONE_STEP_STOP  enabled when its first character is '1'.
//   3. Launches the matching speculate_verify<ENABLE_TOPP, USE_TOPK>
//      instantiation on accept_tokens' stream as a single block of BlockSize
//      threads, then frees the cuRAND states.
//
// accept_tokens, accept_num, step_idx and stop_flags are written in place
// even though they are passed as const references.
//
// The kernel maps one thread to one sequence inside a single block, so
// sequences with index >= BlockSize (512) are not processed.
//
// Errors are reported with printf in the style already used by this file. If
// the state allocation fails the function returns without launching anything.
void SpeculateVerify(const paddle::Tensor &accept_tokens,
                     const paddle::Tensor &accept_num,
                     const paddle::Tensor &step_idx,
                     const paddle::Tensor &stop_flags,
                     const paddle::Tensor &seq_lens_encoder,
                     const paddle::Tensor &seq_lens_decoder,
                     const paddle::Tensor &draft_tokens,
                     const paddle::Tensor &seq_lens_this_time,
                     const paddle::Tensor &verify_tokens,
                     const paddle::Tensor &verify_scores,
                     const paddle::Tensor &max_dec_len,
                     const paddle::Tensor &end_tokens,
                     const paddle::Tensor &is_block_step,
                     const paddle::Tensor &output_cum_offsets,
                     const paddle::Tensor &actual_candidate_len,
                     const paddle::Tensor &actual_draft_token_nums,
                     const paddle::Tensor &topp,
                     int max_seq_len,
                     int verify_window,
                     bool enable_topp) {
    auto bsz = accept_tokens.shape()[0];
    int real_bsz = seq_lens_this_time.shape()[0];
    auto max_draft_tokens = draft_tokens.shape()[1];
    auto end_length = end_tokens.shape()[0];
    auto max_candidate_len = verify_tokens.shape()[1];
    auto stream = accept_tokens.stream();

    constexpr int BlockSize = 512;

    // One generator state per batch row. The allocation result must be
    // checked: launching the kernels on an invalid pointer would corrupt
    // memory or fault asynchronously.
    curandState_t *dev_curand_states = nullptr;
    auto err = cudaMalloc(&dev_curand_states, sizeof(curandState_t) * bsz);
    if (err != 0) {
        printf("SpeculateVerify: cudaMalloc failed, err %d\n", err);
        return;
    }

    setup_kernel<<<1, BlockSize, 0, stream>>>(
        dev_curand_states, seed, offset, static_cast<int>(bsz), true);
    seed++;
    offset++;

    err = cudaDeviceSynchronize();
    if (err != 0) {
        printf("err %d\n", err);
    }

    err = cudaGetLastError();

    if (err != 0) {
        printf("err %d\n", err);
    }

    bool use_topk = false;
    char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
    if (env_var) {
        use_topk = static_cast<bool>(std::stoi(env_var));
    }
    bool prefill_one_step_stop = false;
    if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
        if (env_p[0] == '1') {
            prefill_one_step_stop = true;
        }
    }

    // All four instantiations share one signature, so pick the variant
    // through a function pointer and keep a single launch site.
    auto verify_kernel = &speculate_verify<false, false>;
    if (use_topk && enable_topp) {
        verify_kernel = &speculate_verify<true, true>;
    } else if (use_topk) {
        verify_kernel = &speculate_verify<false, true>;
    } else if (enable_topp) {
        verify_kernel = &speculate_verify<true, false>;
    }

    verify_kernel<<<1, BlockSize, 0, stream>>>(
        const_cast<int64_t *>(accept_tokens.data<int64_t>()),
        const_cast<int *>(accept_num.data<int>()),
        const_cast<int64_t *>(step_idx.data<int64_t>()),
        const_cast<bool *>(stop_flags.data<bool>()),
        seq_lens_encoder.data<int>(),
        seq_lens_decoder.data<int>(),
        draft_tokens.data<int64_t>(),
        actual_draft_token_nums.data<int>(),
        dev_curand_states,
        topp.data<float>(),
        seq_lens_this_time.data<int>(),
        verify_tokens.data<int64_t>(),
        verify_scores.data<float>(),
        max_dec_len.data<int64_t>(),
        end_tokens.data<int64_t>(),
        is_block_step.data<bool>(),
        output_cum_offsets.data<int>(),
        actual_candidate_len.data<int>(),
        real_bsz,
        static_cast<int>(max_draft_tokens),
        static_cast<int>(end_length),
        max_seq_len,
        static_cast<int>(max_candidate_len),
        verify_window,
        prefill_one_step_stop);

    // Launch-configuration errors are only visible through cudaGetLastError.
    err = cudaGetLastError();
    if (err != 0) {
        printf("SpeculateVerify: kernel launch failed, err %d\n", err);
    }

    // NOTE(review): the buffer is released right after an asynchronous launch;
    // this relies on cudaFree not freeing memory still in use by the kernel —
    // confirm against the CUDA runtime documentation.
    cudaFree(dev_curand_states);
}

// Operator registration for `speculate_verify`, with SpeculateVerify set as
// the kernel function.
//
// Every declared output is mapped in place onto the input of the same base
// name (accept_tokens, accept_num, step_idx, stop_flags), matching the
// buffers SpeculateVerify writes through const_cast.
//
// NOTE(review): the Inputs list names "seq_lens_encoder", "seq_lens_decoder",
// "stop_flags" in that order, whereas SpeculateVerify's parameters are ordered
// stop_flags, seq_lens_encoder, seq_lens_decoder. If inputs are bound to the
// kernel function's arguments by position, these three names are mislabeled
// and the "stop_flags" -> "stop_flags_out" in-place entry refers to a
// different slot than the function's stop_flags — TODO confirm against the
// callers before changing either side.
PD_BUILD_STATIC_OP(speculate_verify)
    .Inputs({"accept_tokens",
             "accept_num",
             "step_idx",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "stop_flags",
             "draft_tokens",
             "seq_lens_this_time",
             "verify_tokens",
             "verify_scores",
             "max_dec_len",
             "end_tokens",
             "is_block_step",
             "output_cum_offsets",
             "actual_candidate_len",
             "actual_draft_token_nums",
             "topp"})
    .Outputs({"accept_tokens_out",
              "accept_num_out",
              "step_idx_out",
              "stop_flags_out"})
    // Scalar attributes, passed to SpeculateVerify after the tensor inputs.
    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
    .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                    {"accept_num", "accept_num_out"},
                    {"step_idx", "step_idx_out"},
                    {"stop_flags", "stop_flags_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateVerify));
